#!/usr/bin/env python3
"""
Phase 2: CA Decomposition — Trade vs Income Balance
Research question E: How does Z affect CA through net exports vs. income?

Models:
  1. Z → ca_gdp (reproduce baseline)
  2. Z → trade_balance_gdp
  3. Z → income_balance_gdp
  4. Z → trade_balance_gdp with KAOPEN interactions
  5. Z → income_balance_gdp with KAOPEN interactions
  6. OADR → each component (simpler specification)
  7. NFA creditor/debtor split on income channel
  8. Dynamic: OADR × trend on income balance

Output: extensions/output/tables/ca_decomposition.md
        extensions/output/tables/ca_decomposition_results.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

# ── Paths ──────────────────────────────────────────────────────────────────
# PROJECT_DIR is the directory one level above the folder holding this file;
# ROOT_DIR is one level above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"

# Result tables are written here. The directory is created as an import-time
# side effect (no error if it already exists).
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Put multilateral/src at the front of the import path. This must run before
# the `model` import below — presumably that is where model.py lives.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS


def load_panel():
    """Read the processed extensions panel, keeping rows with year <= 2024.

    Returns:
        An independent copy of the filtered frame, so callers can add
        columns without touching the frame read from disk.
    """
    panel_csv = PROJECT_DIR / "data" / "processed" / "extensions_panel.csv"
    panel = pd.read_csv(panel_csv)
    in_sample = panel["year"] <= 2024
    return panel.loc[in_sample].copy()


def stars(p):
    """Return the significance marker for p-value *p*.

    ``"***"`` when p < 0.01, ``"**"`` when p < 0.05, ``"*"`` when p < 0.1,
    and an empty string otherwise. All thresholds are strict, and a value
    for which every comparison is false (e.g. NaN) also yields ``""``.
    """
    # Ordered from the strictest cutoff so the first match is the strongest.
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.1, "*")):
        if p < cutoff:
            return marker
    return ""


def run_model(df, dep_var, regressors, label):
    """Fit a PanelGLS model on the complete cases of *df* and tabulate it.

    Args:
        df: Panel containing ``iso3`` and ``year`` columns plus the model
            variables.
        dep_var: Name of the dependent-variable column.
        regressors: List of regressor column names, in design-matrix order.
        label: Human-readable model label, used in console output and
            stored in the ``model`` column of the result.

    Returns:
        The frame produced by ``model.to_dataframe`` with ``model``,
        ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared`` and ``rho``
        columns added, or ``None`` when the model is skipped because a
        required column is absent or fewer than 100 complete observations
        remain.
    """
    cols = [dep_var] + regressors

    # A missing column would make dropna() raise KeyError and abort the whole
    # run; treat it like the small-sample case and skip just this model.
    missing = [c for c in cols + ["iso3", "year"] if c not in df.columns]
    if missing:
        print(f"  SKIP {label}: missing columns {missing}")
        return None

    # Estimation sample: rows with no missing value in any model variable.
    sub = df.dropna(subset=cols).copy()
    if len(sub) < 100:
        print(f"  SKIP {label}: only {len(sub)} obs")
        return None

    model = PanelGLS()
    model.fit(sub[dep_var].values, sub[regressors].values,
              sub["iso3"].values, sub["year"].values)
    print(f"\n  {label}")
    model.summary(feature_names=regressors)

    # Attach model-level metadata to every coefficient row so results from
    # several models can be concatenated into one long table.
    res = model.to_dataframe(feature_names=regressors)
    res["model"] = label
    res["dep_var"] = dep_var
    res["n_obs"] = model.n_obs
    res["n_countries"] = model.n_countries
    res["r_squared"] = model.r_squared
    res["rho"] = model.rho
    return res


def main():
    """Estimate all CA-decomposition models and write the result tables.

    Loads the panel, keeps only those EBA controls and Z×KAOPEN interaction
    columns that are present with more than 200 non-missing values, and runs
    models M1–M8 in order through ``run_model``. When at least one model was
    estimated, the stacked coefficients are written to
    ``ca_decomposition_results.csv`` and summarised in markdown via
    ``write_markdown_table``.
    """
    print("=" * 70)
    print("PHASE 2: CA Decomposition — Trade vs Income Balance")
    print("=" * 70)

    df = load_panel()

    def usable(col):
        # A column is usable only if it exists and has enough data.
        return col in df.columns and df[col].notna().sum() > 200

    demo_vars = ["Z_1", "Z_2", "Z_3"]
    eba_controls = [c for c in ["fiscal_bal_gdp", "nfa_gdp_lag",
                                "log_rel_opw", "kaopen"] if usable(c)]
    interaction_vars = [c for c in ["Z_1_x_kaopen", "Z_2_x_kaopen",
                                    "Z_3_x_kaopen"] if usable(c)]

    # ── Regressor sets ─────────────────────────────────────────────────────
    base_vars = demo_vars + eba_controls
    ext_vars = base_vars + interaction_vars
    controls_no_kaopen = [c for c in eba_controls if c != "kaopen"]
    simple_controls = ["old_dep"] + controls_no_kaopen
    # M7 replaces lagged NFA with its creditor/debtor split.
    creditor_vars = ([v for v in base_vars if v != "nfa_gdp_lag"]
                     + ["nfa_positive", "nfa_negative"])
    dynamic_vars = (["old_dep", "time_trend", "oadr_x_trend", "old_dep_sq"]
                    + controls_no_kaopen)

    # Without any interaction column, M4/M5 would just re-estimate M2/M3
    # under a misleading label, so they are disabled in that case.
    has_interactions = bool(interaction_vars)

    # (dependent variable, regressors, label, enabled) — run in this order.
    specs = [
        ("ca_gdp", base_vars, "M1: Z → CA/GDP (baseline)", True),
        ("trade_balance_gdp", base_vars, "M2: Z → Trade Balance", True),
        ("income_balance_gdp", base_vars, "M3: Z → Income Balance", True),
        ("trade_balance_gdp", ext_vars,
         "M4: Z → Trade Balance + Z×KAOPEN", has_interactions),
        ("income_balance_gdp", ext_vars,
         "M5: Z → Income Balance + Z×KAOPEN", has_interactions),
        ("ca_gdp", simple_controls, "M6a: OADR → CA", True),
        ("trade_balance_gdp", simple_controls, "M6b: OADR → Trade Bal", True),
        ("income_balance_gdp", simple_controls,
         "M6c: OADR → Income Bal", True),
        ("income_balance_gdp", creditor_vars,
         "M7: NFA split → Income Balance", True),
        ("income_balance_gdp", dynamic_vars,
         "M8: OADR dynamics → Income Balance", True),
    ]

    all_results = []
    for dep_var, regressors, label, enabled in specs:
        if not enabled:
            print(f"  SKIP {label}: no Z×KAOPEN interaction columns available")
            continue
        res = run_model(df, dep_var, regressors, label)
        if res is not None:
            all_results.append(res)

    # ── Combine and save ───────────────────────────────────────────────────
    if not all_results:
        print("\n  Done. No model could be estimated; nothing was saved.")
        return

    results_df = pd.concat(all_results, ignore_index=True)
    results_df.to_csv(OUT_TABLES / "ca_decomposition_results.csv",
                      index=False)

    # Build markdown table
    write_markdown_table(results_df, all_results)

    print("\n  Done. Results saved to extensions/output/tables/")


def write_markdown_table(results_df, all_results):
    """Write the markdown summary of the decomposition results.

    Two tables are written to ``OUT_TABLES / "ca_decomposition.md"``: a
    per-model summary (dependent variable, N, countries, R², ρ) taken from
    each model's first row, and a coefficient table with one row per
    (key variable, model) pair.

    Args:
        results_df: Stacked results with columns ``model``, ``dep_var``,
            ``n_obs``, ``n_countries``, ``r_squared``, ``rho``,
            ``variable``, ``coefficient``, ``std_error`` and ``p_value``.
        all_results: Per-model result frames. Not used; kept so the
            signature stays compatible with existing callers.
    """
    lines = ["# CA Decomposition: Trade vs Income Balance", ""]
    lines.append("## Main Results")
    lines.append("")

    # Model comparison summary
    lines.append("| Model | Dep Var | N | Countries | R² | ρ |")
    lines.append("|-------|---------|---|-----------|----|----|")
    for m in results_df["model"].unique():
        # Model-level metadata is repeated on every row; read the first.
        first = results_df[results_df["model"] == m].iloc[0]
        lines.append(f"| {m} | {first['dep_var']} | {first['n_obs']:,} | "
                     f"{first['n_countries']} | {first['r_squared']:.3f} | "
                     f"{first['rho']:.3f} |")

    lines.append("")

    # Key coefficients
    lines.append("## Key Coefficients")
    lines.append("")
    key_vars = ["Z_1", "Z_2", "Z_3", "old_dep", "nfa_gdp_lag",
                "nfa_positive", "nfa_negative", "oadr_x_trend", "old_dep_sq",
                "Z_1_x_kaopen", "Z_2_x_kaopen", "Z_3_x_kaopen"]

    lines.append("| Variable | Model | Coef | SE | p-value |")
    lines.append("|----------|-------|------|----|---------|")
    for v in key_vars:
        sub = results_df[results_df["variable"] == v]
        for _, row in sub.iterrows():
            # A missing p-value gets no stars.
            sig = stars(row["p_value"]) if pd.notna(row["p_value"]) else ""
            lines.append(f"| {v} | {row['model']} | "
                         f"{row['coefficient']:.4f}{sig} | "
                         f"{row['std_error']:.4f} | "
                         f"{row['p_value']:.4f} |")

    md_path = OUT_TABLES / "ca_decomposition.md"
    # The content contains non-ASCII characters (→, ×, ², ρ). Fix the
    # encoding so the write does not depend on the platform's locale default.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Markdown table: {md_path}")


if __name__ == "__main__":
    main()
